---
title: "R Notebook"
output: html_notebook
---

# Simulating one-armed Thompson sampling in population
Mean reward from arm 0 is normalized to 0. The distribution of outcomes in arm 1 is normal with unknown mean $\mu$ and known variance $\sigma^2$. 

```{r}
#' Simulate one-armed Thompson sampling with Gaussian outcomes
#'
#' The unknown arm is always pulled in the first period. In every later
#' period it is pulled with probability `pnorm(mu_post / sd_post)`, i.e. the
#' probability that a N(mu_post, sd_post^2) draw is positive.
#'
#' @param mu Scaled mean of the unknown arm; each pull has mean
#'   `mu / sqrt(N)`.
#' @param theta_0 Only used to set the error standard deviation
#'   `sqrt(theta_0 * (1 - theta_0))`.
#' @param tau Added to the pull count `Q` in the posterior mean and standard
#'   deviation (presumably a prior-precision term; confirm against callers).
#' @param N Number of periods (sample size).
#' @return Numeric vector of length 2: `X / sqrt(N)`, the scaled cumulative
#'   sum of outcomes, and `Q / N`, the fraction of periods in which the
#'   unknown arm was pulled.
TS_one_arm <- function(mu, theta_0, tau, N) {
  sigma <- sqrt(theta_0 * (1 - theta_0))
  step_mean <- mu / sqrt(N)
  errors <- rnorm(N, mean = 0, sd = sigma) # pre-generate all errors

  X <- step_mean + errors[1] # cumulative sum of outcomes
  Q <- 1                     # number of times the arm was pulled

  # seq_len(N)[-1] is 2, ..., N and is empty when N == 1, whereas 2:N would
  # count down through c(2, 1) and read past the end of `errors`.
  for (iter in seq_len(N)[-1]) {
    mu_post <- X / (Q + tau)
    sd_post <- sigma / sqrt(Q + tau)

    # Named prob_pull rather than pi to avoid masking the built-in constant.
    prob_pull <- pnorm(mu_post / sd_post)
    A <- rbinom(1, 1, prob_pull)

    if (A == 1) {
      X <- X + step_mean + errors[iter]
      Q <- Q + 1
    }
  }

  c(X / sqrt(N), Q / N)
}
```

## Example
```{r}
# One simulated path with zero mean (mu = 0) and tau = 0
mu <- 0
theta_0 <- 0.05
tau <- 0
N <- 50000
TS_one_arm(mu, theta_0, tau, N)
```


## Computing critical values for weighted average power using limit experiment
```{r}
#' Critical value of the weighted likelihood-ratio statistic (limit experiment)
#'
#' Simulates `B` paths of `TS_one_arm()` with `mu = 0` and `tau = 0`. For each
#' path (x, q) it averages `exp((mu * x - q * mu^2 / 2) / sigma_0^2)` over an
#' equally spaced grid of `M` values of `mu` on `[-range, range]`, then
#' returns the `1 - test_size` quantile of those averages.
#'
#' @param theta_0 Sets `sigma_0 = sqrt(theta_0 * (1 - theta_0))`.
#' @param N Number of periods passed to `TS_one_arm()`.
#' @param range Half-width of the grid of `mu` values.
#' @param M Number of grid points.
#' @param B Number of simulated paths.
#' @param test_size Nominal size; the `1 - test_size` quantile is returned.
#' @param alpha_0,beta_0 Accepted but not used while `tau` is fixed at 0.
#' @return The value returned by `quantile()` (a named numeric of length 1).
WAP_critical_value <- function(theta_0,
                               N,
                               range,
                               M,
                               B,
                               test_size,
                               alpha_0,
                               beta_0) {
  sigma_0 <- sqrt(theta_0 * (1 - theta_0))
  sigma_sq <- sigma_0^2

  # tau is pinned at 0 here; the prior-based alternative is kept for reference:
  # tau <- ((alpha_0 + beta_0 + 1) * (alpha_0 + beta_0)^2) / (alpha_0 * beta_0 * N)
  tau <- 0

  mu_grid <- seq(from = -range, to = range, length.out = M)
  grid_sq <- mu_grid^2

  # One column per path: row 1 is X / sqrt(N), row 2 is Q / N
  null_draws <- replicate(B, TS_one_arm(0, theta_0, tau, N))

  # Grid average of the likelihood ratio for one path
  average_lr <- function(stat) {
    log_lr <- (mu_grid * stat[1] - (stat[2] * grid_sq / 2)) / sigma_sq
    mean(exp(log_lr))
  }

  lr_draws <- apply(null_draws, 2, average_lr)
  quantile(lr_draws, 1 - test_size)
}

# Settings for the limit-experiment critical value
alpha_0 <- 1
beta_0 <- 20
theta_0 <- 0.05
N <- 25000
M <- 2500
B <- 20000
range <- 1.5
test_size <- 0.05

asymptotic_critical_value <- WAP_critical_value(
  theta_0, N, range, M, B, test_size, alpha_0, beta_0
)
```

```{r}
# Weighted average power in the limit experiment

#' Rejection rate of the weighted likelihood-ratio test at scaled mean `h`
#'
#' Simulates `B` paths of `TS_one_arm()` with mean parameter `h` and
#' `tau` computed from `alpha_0`, `beta_0` and `N`, evaluates the grid-averaged
#' likelihood ratio for each path, and returns the share of paths whose
#' statistic is at least `asymptotic_critical_value`.
#'
#' @param h Scaled mean passed to `TS_one_arm()` as `mu`.
#' @param N Number of periods.
#' @param theta_0 Sets `sigma_0 = sqrt(theta_0 * (1 - theta_0))`.
#' @param range Half-width of the grid of `mu` values.
#' @param M Number of grid points.
#' @param B Number of simulated paths.
#' @param alpha_0,beta_0 Enter only through `tau`.
#' @param asymptotic_critical_value Rejection threshold for the statistic.
#' @return Fraction of simulated statistics at or above the threshold.
WAP_asymptotic_power <- function(h,
                                 N,
                                 theta_0,
                                 range,
                                 M,
                                 B,
                                 alpha_0,
                                 beta_0,
                                 asymptotic_critical_value) {
  sigma_0 <- sqrt(theta_0 * (1 - theta_0))
  sigma_sq <- sigma_0^2

  prior_sum <- alpha_0 + beta_0
  tau <- ((prior_sum + 1) * prior_sum^2) / (alpha_0 * beta_0 * N)

  mu_grid <- seq(from = -range, to = range, length.out = M)
  grid_sq <- mu_grid^2

  # One column per path: row 1 is X / sqrt(N), row 2 is Q / N
  path_draws <- replicate(B, TS_one_arm(h, theta_0, tau, N))

  # Grid average of the likelihood ratio for one path
  average_lr <- function(stat) {
    log_lr <- (mu_grid * stat[1] - (stat[2] * grid_sq / 2)) / sigma_sq
    mean(exp(log_lr))
  }

  lr_draws <- apply(path_draws, 2, average_lr)
  mean(lr_draws >= asymptotic_critical_value)
}

# Rebind to a vectorized wrapper so a vector of h values can be passed at once
WAP_asymptotic_power <- Vectorize(WAP_asymptotic_power)
```

```{r}
# Settings for the asymptotic power curve
theta_0 <- 0.05
alpha_0 <- 1
beta_0 <- 20
range <- 1.5
N <- 25000
M <- 2500
B <- 10000

# 25 equally spaced values of the scaled mean on [-range, range]
seq1 <- seq(from = -range, to = range, length.out = 25)
power_asymptotic <- WAP_asymptotic_power(
  seq1, N, theta_0, range, M, B, alpha_0, beta_0,
  asymptotic_critical_value
)
```

------------------------------------------------------------
# Simulating Thompson sampling (TS) in the Bernoulli setting

```{r}
#' Simulate one-armed Thompson sampling with Bernoulli rewards
#'
#' In each period a value is drawn from the current Beta(alpha, beta)
#' posterior; the unknown arm is pulled when that draw is at least `theta_0`.
#' The posterior and the running totals change only in periods where the arm
#' is pulled.
#'
#' @param theta Success probability used to generate the rewards.
#' @param theta_0 Threshold the posterior draw is compared with; also
#'   subtracted from each observed reward when accumulating `X`.
#' @param alpha_0,beta_0 Starting parameters of the Beta posterior.
#' @param n Number of periods.
#' @return Numeric vector of length 2: `X / sqrt(n)`, the scaled sum of
#'   `reward - theta_0` over pulls, and `Q / n`, the fraction of periods in
#'   which the arm was pulled. Both entries are NaN when `n` is 0.
TS_one_arm_bernoulli <- function(theta, theta_0, alpha_0, beta_0, n) {
  alpha <- alpha_0
  beta <- beta_0
  X <- 0
  Q <- 0

  # Rewards are generated up front; reward_vec[i] is only used if the arm is
  # pulled in period i.
  reward_vec <- rbinom(n, 1, theta)

  # seq_len(n) is empty when n == 0, whereas 1:n would iterate over c(1, 0)
  # and index past the (empty) reward vector.
  for (i in seq_len(n)) {
    # Obtain a random draw from the Beta posterior
    draw <- rbeta(1, alpha, beta)

    # Choose the unknown arm if the draw is at least theta_0
    A <- (draw >= theta_0)

    # Update the posterior (A is FALSE/0 when the arm is not pulled)
    alpha <- alpha + A * (reward_vec[i] == 1)
    beta <- beta + A * (reward_vec[i] == 0)

    X <- X + A * (reward_vec[i] - theta_0)
    Q <- Q + A
  }

  c(X / sqrt(n), Q / n)
}

# Example: rewards generated with theta equal to theta_0
theta_0 <- 0.05
alpha_0 <- 1
beta_0 <- 20
n <- 10000
TS_one_arm_bernoulli(theta_0, theta_0, alpha_0, beta_0, n)
```

## Finite sample size of naive t-test
```{r}
#' Monte Carlo rejection rate of the naive test when theta equals theta_0
#'
#' Runs `B` replications of `TS_one_arm_bernoulli()` with `theta = theta_0`
#' and rejects when `|X / sqrt(n)| / (Q / n) >= 1.96 * sigma`, where
#' `sigma = sqrt(theta_0 * (1 - theta_0))`.
#'
#' NOTE(review): the statistic divides by the sampling fraction `Q / n`, not
#' by its square root; confirm this is the intended normalization.
#'
#' @param n Number of periods per replication.
#' @param theta_0 Null success probability.
#' @param alpha_0,beta_0 Starting Beta parameters for the sampler.
#' @param B Number of replications.
#' @return Fraction of replications in which the test rejects.
finite_sample_size <- function(n, theta_0, alpha_0, beta_0, B) {
  sigma <- sqrt(theta_0 * (1 - theta_0))

  # One column per replication: row 1 is X / sqrt(n), row 2 is Q / n
  draws <- replicate(
    B,
    expr = TS_one_arm_bernoulli(theta_0, theta_0, alpha_0, beta_0, n)
  )

  scaled_sum <- draws[1, ]
  sampled_share <- draws[2, ]
  reject <- abs(scaled_sum) / sampled_share >= 1.96 * sigma

  mean(reject)
}

# Rebind to a vectorized wrapper so a vector of sample sizes can be passed
finite_sample_size <- Vectorize(finite_sample_size)

# Rejection rate of the naive test at each sample size
n_values <- c(2500, 5000, 7500, 10000)
B <- 20000
size_vs_n_t_test <- finite_sample_size(n_values, theta_0, alpha_0, beta_0, B)
```

## Computing critical values for weighted average power in finite samples
```{r}
#' Finite-sample critical value of the weighted likelihood-ratio statistic
#'
#' Simulates `B` runs of `TS_one_arm_bernoulli()` with `theta = theta_0`. For
#' each run (x, q) it averages `exp((mu * x - q * mu^2 / 2) / sigma_0^2)` over
#' an equally spaced grid of `M` values of `mu` on `[-range, range]`, then
#' returns the `1 - test_size` quantile of those averages.
#'
#' @param theta_0 Null success probability; also sets
#'   `sigma_0 = sqrt(theta_0 * (1 - theta_0))`.
#' @param n Number of periods per run.
#' @param range Half-width of the grid of `mu` values.
#' @param M Number of grid points.
#' @param B Number of simulated runs.
#' @param alpha_0,beta_0 Starting Beta parameters for the sampler.
#' @param test_size Nominal size; the `1 - test_size` quantile is returned.
#' @return The value returned by `quantile()` (a named numeric of length 1).
WAP_finite_sample_critical_value <- function(theta_0,
                                             n,
                                             range,
                                             M,
                                             B,
                                             alpha_0,
                                             beta_0,
                                             test_size) {
  sigma_0 <- sqrt(theta_0 * (1 - theta_0))
  sigma_sq <- sigma_0^2
  mu_grid <- seq(from = -range, to = range, length.out = M)
  grid_sq <- mu_grid^2

  # One column per run: row 1 is X / sqrt(n), row 2 is Q / n
  null_draws <- replicate(
    B,
    TS_one_arm_bernoulli(theta_0, theta_0, alpha_0, beta_0, n)
  )

  # Grid average of the likelihood ratio for one run
  average_lr <- function(stat) {
    log_lr <- (mu_grid * stat[1] - (stat[2] * grid_sq / 2)) / sigma_sq
    mean(exp(log_lr))
  }

  lr_draws <- apply(null_draws, 2, average_lr)
  quantile(lr_draws, 1 - test_size)
}
```

## Finite sample size of WAP test
```{r}
#' Finite-sample rejection rate of the weighted likelihood-ratio test
#' when theta equals theta_0
#'
#' Simulates `B` runs of `TS_one_arm_bernoulli()` with `theta = theta_0`,
#' evaluates the grid-averaged likelihood ratio for each run, and returns the
#' share of runs whose statistic is at least `critical_value_finite_n`.
#'
#' @param theta_0 Null success probability; also sets
#'   `sigma_0 = sqrt(theta_0 * (1 - theta_0))`.
#' @param n Number of periods per run.
#' @param range Half-width of the grid of `mu` values.
#' @param M Number of grid points.
#' @param B Number of simulated runs.
#' @param alpha_0,beta_0 Starting Beta parameters for the sampler.
#' @param critical_value_finite_n Rejection threshold for the statistic.
#' @return Fraction of simulated statistics at or above the threshold.
WAP_finite_sample_size <- function(theta_0,
                                   n,
                                   range,
                                   M,
                                   B,
                                   alpha_0,
                                   beta_0,
                                   critical_value_finite_n) {
  sigma_0 <- sqrt(theta_0 * (1 - theta_0))
  sigma_sq <- sigma_0^2
  mu_grid <- seq(from = -range, to = range, length.out = M)
  grid_sq <- mu_grid^2

  # One column per run: row 1 is X / sqrt(n), row 2 is Q / n
  null_draws <- replicate(
    B,
    TS_one_arm_bernoulli(theta_0, theta_0, alpha_0, beta_0, n)
  )

  # Grid average of the likelihood ratio for one run
  average_lr <- function(stat) {
    log_lr <- (mu_grid * stat[1] - (stat[2] * grid_sq / 2)) / sigma_sq
    mean(exp(log_lr))
  }

  lr_draws <- apply(null_draws, 2, average_lr)
  reject <- lr_draws >= critical_value_finite_n

  mean(reject)
}
```


## Power properties 
```{r}
#' Finite-sample rejection rate of the weighted likelihood-ratio test at a
#' local alternative
#'
#' Sets `theta = theta_0 + h / sqrt(n)`, simulates `B` runs of
#' `TS_one_arm_bernoulli()` at that `theta`, evaluates the grid-averaged
#' likelihood ratio for each run, and returns the share of runs whose
#' statistic is at least `critical_value`.
#'
#' @param h Scaled deviation from `theta_0`.
#' @param n Number of periods per run.
#' @param theta_0 Null success probability; also sets
#'   `sigma_0 = sqrt(theta_0 * (1 - theta_0))`.
#' @param range Half-width of the grid of `mu` values.
#' @param M Number of grid points.
#' @param B Number of simulated runs.
#' @param alpha_0,beta_0 Starting Beta parameters for the sampler.
#' @param critical_value Rejection threshold for the statistic.
#' @return Fraction of simulated statistics at or above the threshold.
WAP_finite_sample_power <- function(h,
                                    n,
                                    theta_0,
                                    range,
                                    M,
                                    B,
                                    alpha_0,
                                    beta_0,
                                    critical_value) {
  sigma_0 <- sqrt(theta_0 * (1 - theta_0))
  sigma_sq <- sigma_0^2
  theta <- theta_0 + (h / sqrt(n))
  mu_grid <- seq(from = -range, to = range, length.out = M)
  grid_sq <- mu_grid^2

  # One column per run: row 1 is X / sqrt(n), row 2 is Q / n
  alt_draws <- replicate(
    B,
    TS_one_arm_bernoulli(theta, theta_0, alpha_0, beta_0, n)
  )

  # Grid average of the likelihood ratio for one run
  average_lr <- function(stat) {
    log_lr <- (mu_grid * stat[1] - (stat[2] * grid_sq / 2)) / sigma_sq
    mean(exp(log_lr))
  }

  lr_draws <- apply(alt_draws, 2, average_lr)
  mean(lr_draws >= critical_value)
}

# Rebind to a vectorized wrapper so a vector of h values can be passed at once
WAP_finite_sample_power <- Vectorize(WAP_finite_sample_power)
```

## Computing power for various sample sizes
```{r}
# 25 equally spaced values of the scaled effect on [-range, range]
seq1 <- seq(from = -range, to = range, length.out = 25)

# Critical values, one per sample size. The calls stay in the order
# 5000, 10000, 7500, 2500 so random numbers are consumed in the same sequence.
B <- 25000
critical_value_finite_n5k <- WAP_finite_sample_critical_value(
  theta_0, 5000, range, M, B, alpha_0, beta_0, test_size
)

critical_value_finite_n10k <- WAP_finite_sample_critical_value(
  theta_0, 10000, range, M, B, alpha_0, beta_0, test_size
)

critical_value_finite_n7k <- WAP_finite_sample_critical_value(
  theta_0, 7500, range, M, B, alpha_0, beta_0, test_size
)

critical_value_finite_n2k <- WAP_finite_sample_critical_value(
  theta_0, 2500, range, M, B, alpha_0, beta_0, test_size
)

# Power curves over seq1, each using the critical value for its sample size
B <- 10000
power_n5K <- WAP_finite_sample_power(
  seq1, 5000, theta_0, range, M, B, alpha_0, beta_0,
  critical_value_finite_n5k
)

power_n10K <- WAP_finite_sample_power(
  seq1, 10000, theta_0, range, M, B, alpha_0, beta_0,
  critical_value_finite_n10k
)

power_n7K <- WAP_finite_sample_power(
  seq1, 7500, theta_0, range, M, B, alpha_0, beta_0,
  critical_value_finite_n7k
)

power_n2K <- WAP_finite_sample_power(
  seq1, 2500, theta_0, range, M, B, alpha_0, beta_0,
  critical_value_finite_n2k
)
```

```{r}
library(tidyverse)
library(latex2exp)

# Collect the power curves: one row per grid value, one column per sample size
# plus the limit-experiment curve
df1 <- data.frame(
  seq = seq1,
  n2500 = power_n2K,
  n5000 = power_n5K,
  n7500 = power_n7K,
  n10000 = power_n10K,
  power_limit = power_asymptotic
)

#setwd("/Users/akarun/Library/CloudStorage/Dropbox/Optimal sequential tests/Figures")
write_csv(df1, "df1.csv")
```


# Plot figures

## Plot size
```{r}
n_values <- c(2500, 5000, 7500, 10000)
#df1 = read_csv("df1.csv")

# Row 13 is the midpoint of the 25-point grid in the `seq` column, so the
# finite-sample power columns in that row are used as the size of the WAP test.
size_WAP <- df1 %>%
  select(-seq, -power_limit) %>%
  slice(13) %>%
  as.numeric()

# Long format: one row per (sample size, test) pair, labelled for plotting
size.df <- data.frame(
  n_seq = n_values,
  size_WAP = size_WAP,
  size_two_sample = size_vs_n_t_test
) %>%
  pivot_longer(-n_seq, names_to = "type", values_to = "size") %>%
  mutate(
    type = if_else(type == "size_WAP", "Proposed test", "Naive t-test"),
    type = factor(type, levels = c("Proposed test", "Naive t-test"))
  )
size.df
```

```{r}
# Size against sample size for both tests; the dashed red line marks 0.05
ggplot(size.df, aes(x = n_seq, y = size, color = type)) +
  geom_point(size = 3) +
  geom_line(linewidth = 1) +
  labs(x = "Sample size (n)", y = "Size", color = "Type of test") +
  geom_hline(
    yintercept = 0.05,
    linetype = "dashed",
    linewidth = 0.8,
    color = "red"
  ) +
  # geom_hline(yintercept = asymptotic_size_t_test,
  #            linetype = "dashed",
  #            linewidth = 0.8,
  #            color = "blue") +
  theme(
    # aspect.ratio = 4/3,
    # legend.title = "Type of test"
    # plot.margin = grid::unit(c(0, 0, 0, 0), "mm")
    text = element_text(size = 13),
    axis.text = element_text(size = 13),
    axis.title = element_text(size = 14),
    legend.position = c(0.87, 0.4)
  ) +
  ylim(0, 0.37)


#setwd("/Users/karunadusumilli/Dropbox/Optimal sequential tests/Figures")
ggsave("Size_one_armed_bandit.png")
```

## Plot power
```{r}
# Rename the columns to the labels shown in the legend
names(df1) <- c(
  "delta_seq", "2500", "5000", "7500", "10000", "Asymptotic power function"
)

# Long format; the factor levels follow the column order of df1 so the legend
# lists the sample sizes in increasing order with the asymptotic curve last
power.df1 <- df1 %>%
  pivot_longer(-delta_seq, names_to = "n", values_to = "Power") %>%
  mutate(n = factor(n, levels = names(df1)[-1]))

# Power curves; the dashed blue line marks 0.05
ggplot(power.df1, aes(x = delta_seq, y = Power, color = n)) +
  geom_point() +
  geom_line() +
  geom_hline(
    yintercept = 0.05,
    linetype = "dashed",
    linewidth = 0.3,
    color = "blue"
  ) +
  labs(
    x = TeX(r"(Scaled treatment effect $\sqrt{n} (\theta - \theta_0)$ )"),
    y = "Power"
  ) +
  theme(
    # aspect.ratio = 4/3,
    text = element_text(size = 13),
    axis.text = element_text(size = 12),
    axis.title = element_text(size = 13),
    legend.position = c(0.2, 0.75)
  ) +
  scale_color_brewer(palette = "RdGy")

#setwd("/Users/akarun/Library/CloudStorage/Dropbox/Optimal sequential tests/Figures")
ggsave("Power_envelope_one_armed_bandit.png")
```